[Stones from Other Hills] TensorFlow TFRecord: How It Works and Lessons from Using It
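"Stones from other hills can polish jade": standing on the shoulders of giants lets us see higher and go further, and research moves faster with a good tailwind. To that end, we collect practical code links, datasets, software, and programming tips in the "Stones from Other Hills" column to help you on your way. Stay tuned.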
Source: https://zhuanlan.zhihu.com/p/352025069
01
02
```protobuf
message Example {
  Features features = 1;
}

message SequenceExample {
  Features context = 1;
  FeatureLists feature_lists = 2;
}

message Features {
  // Map from feature name to feature.
  map<string, Feature> feature = 1;
}

// Containers to hold repeated fundamental values.
message BytesList {
  repeated bytes value = 1;
}
message FloatList {
  repeated float value = 1 [packed = true];
}
message Int64List {
  repeated int64 value = 1 [packed = true];
}

// Containers for non-sequential data.
message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
}

// Containers for sequential data.
//
// A FeatureList contains lists of Features. These may hold zero or more
// Feature values.
//
// FeatureLists are organized into categories by name. The FeatureLists message
// contains the mapping from name to FeatureList.
//
message FeatureList {
  repeated Feature feature = 1;
}

message FeatureLists {
  // Map from feature name to feature list.
  map<string, FeatureList> feature_list = 1;
}
```
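Each record in a TFRecord file is stored on disk as follows (integers are little-endian, and the checksums are masked CRC-32C values):

```text
uint64 length
uint32 masked_crc32_of_length
byte   data[length]
uint32 masked_crc32_of_data
```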
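The framing can be sketched with the standard library alone. This is an illustrative re-implementation under the assumption of an uncompressed file, not TensorFlow's code (which does this in C++): the names crc32c, masked_crc32c, encode_record, and iter_records are hypothetical. The CRC is CRC-32C (Castagnoli, reflected polynomial 0x82F63B78), computed over the 8 length bytes and over the payload, then masked by rotating right 15 bits and adding the constant 0xA282EAD8.

```python
import struct

def _crc32c_table():
    """Lookup table for table-driven CRC-32C (reflected polynomial 0x82F63B78)."""
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table

_TABLE = _crc32c_table()

def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for b in data:
        crc = _TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF

def masked_crc32c(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, modulo 2**32."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF

def encode_record(data: bytes) -> bytes:
    """Frame one payload: length, CRC of length, data, CRC of data."""
    header = struct.pack("<Q", len(data))
    return (header + struct.pack("<I", masked_crc32c(header))
            + data + struct.pack("<I", masked_crc32c(data)))

def iter_records(buf: bytes):
    """Yield each payload in buf, verifying both checksums."""
    pos = 0
    while pos < len(buf):
        if pos + 12 > len(buf):
            raise ValueError("truncated record header")
        header = buf[pos:pos + 8]
        (length,) = struct.unpack("<Q", header)
        (length_crc,) = struct.unpack("<I", buf[pos + 8:pos + 12])
        if length_crc != masked_crc32c(header):
            raise ValueError("corrupted length field")
        start, end = pos + 12, pos + 12 + length
        if end + 4 > len(buf):
            raise ValueError("truncated record data")
        data = buf[start:end]
        (data_crc,) = struct.unpack("<I", buf[end:end + 4])
        if data_crc != masked_crc32c(data):
            raise ValueError("corrupted record data")
        yield data
        pos = end + 4
```

Each record therefore costs 16 bytes of overhead. Reading a file would look like list(iter_records(open(path, "rb").read())), where every element is one serialized Example.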
03
```python
import tensorflow as tf

# Recall from the previous section that an Example holds several kinds of
# Feature. The four helper functions below make building them convenient.
def _bytes_feature(value):
    # BytesList stores bytes; encode Python 3 str values as UTF-8 first
    if isinstance(value, str):
        value = value.encode("utf-8")
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    # FloatList stores 32-bit floats
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _int64list_feature(value_list):
    # Variable-length list of int64 values
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value_list))

# Serialize one Example into a byte string
def serialize_example(user_id, city_id, app_type, viewed_pois, avg_paid, comment):
    # Assemble the fields according to the schema; this dict defines one Example
    feature = {
        'user_id': _int64_feature(user_id),
        'city_id': _int64_feature(city_id),
        'app_type': _int64_feature(app_type),
        'viewed_pois': _int64list_feature(viewed_pois),
        'avg_paid': _float_feature(avg_paid),
        'comment': _bytes_feature(comment),
    }
    # Wrap the dict in an Example and serialize it to a byte string
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

# Produce samples: write two example records into a TFRecord file.
# tf.python_io.TFRecordWriter is the TF 1.x name; TF 2.x uses tf.io.TFRecordWriter.
def write_demo(filepath):
    with tf.python_io.TFRecordWriter(filepath) as writer:
        writer.write(serialize_example(1, 10, 1, [658, 325], 36.3, "yummy food."))
        writer.write(serialize_example(2, 20, 2, [897, 568, 126], 89.6, "nice place to have dinner."))
    print("write demo data done.")

filepath = "testdata.tfrecord"
write_demo(filepath)
```
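One detail worth knowing about the demo data: FloatList's value field is declared as repeated float, which is 32-bit, while a Python float is 64-bit. So avg_paid=36.3 is stored as the nearest float32 and will not read back as exactly 36.3. The stdlib sketch below emulates that storage with struct; as_float32 is a hypothetical helper, not a TensorFlow API.

```python
import struct

def as_float32(x: float) -> float:
    """Round-trip a Python float through IEEE-754 single precision,
    mimicking how a FloatList value is stored."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

stored = as_float32(36.3)
print(stored, stored == 36.3)
```

If exact decimal values matter, one option is to store a scaled integer (for example, an amount in cents) in an Int64List instead.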
04
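```python
from pyspark.sql import SparkSession

# Number of output TFRecord files (one per partition).
# Illustrative value only; the original article does not specify it.
file_num = 10

def main():
    # Requires the spark-tensorflow-connector jar on the classpath,
    # which provides the "tfrecords" data source
    spark = (SparkSession.builder
             .appName("hive_to_tfrecord")
             .enableHiveSupport()
             .getOrCreate())
    # Read data from the Hive table
    df = spark.sql("""
        select * from experiment.table""")
    # Output path for the TFRecord files
    path = "viewfs:///user/hadoop-hdp/ml/demo/tensorflow/data/tfrecord"
    # Convert the Spark DataFrame to TFRecord (Example) format;
    # repartition controls how many output files are written
    df.repartition(file_num).write \
        .mode("overwrite") \
        .format("tfrecords") \
        .option("recordType", "Example") \
        .save(path)

if __name__ == "__main__":
    main()
```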
05
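```python
import tensorflow as tf

# TF 1.x API (tf.Session, make_one_shot_iterator). In TF 2.x, use
# tf.io.parse_single_example and iterate over the dataset eagerly instead.
def read_demo(filepath):
    # Define the schema; keys must match the names used when writing
    schema = {
        'user_id': tf.FixedLenFeature([], tf.int64),
        'city_id': tf.FixedLenFeature([], tf.int64),
        'app_type': tf.FixedLenFeature([], tf.int64),
        'viewed_pois': tf.VarLenFeature(tf.int64),
        'avg_paid': tf.FixedLenFeature([], tf.float32, default_value=0.0),
        'comment': tf.FixedLenFeature([], tf.string, default_value=''),
    }

    # Parse one serialized Example according to the schema
    def _parse_function(example_proto):
        return tf.parse_single_example(example_proto, schema)

    # Create a dataset by reading the TFRecord file
    dataset = tf.data.TFRecordDataset(filepath)
    # Parse every record in the dataset with the schema
    parsed_dataset = dataset.map(_parse_function)
    # Create an iterator; each run of next_element yields the next parsed sample
    next_element = parsed_dataset.make_one_shot_iterator().get_next()
    # Print every sample in the dataset using a session
    with tf.Session() as sess:
        while True:
            try:
                print(sess.run(next_element))
            except tf.errors.OutOfRangeError:
                print("out of data")
                break

read_demo("testdata.tfrecord")
```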
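The signature of tf.parse_single_example (TF 1.x; tf.io.parse_single_example in TF 2.x) is:

```python
tf.parse_single_example(
    serialized,
    features,
    name=None,
    example_names=None
)
```

- serialized: a scalar string Tensor holding one serialized Example.
- features: a dict whose keys are feature names and whose values are FixedLenFeature, VarLenFeature, or FixedLenSequenceFeature objects.
- name: (optional) a name for this operation.
- example_names: (optional) a scalar string Tensor, the name associated with the serialized Example.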
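The difference between the two spec types used in read_demo can be modeled in a few lines. The sketch below is a simplified pure-Python model, not TensorFlow code: FixedLen and VarLen are hypothetical stand-ins, and dtype checks and multi-dimensional shapes are ignored. It encodes the rules for a scalar FixedLenFeature (shape []), which needs exactly one value, falls back to default_value when the key is absent, and is an error when the key is absent with no default; and for a VarLenFeature, which yields a 1-D sparse result (indices, values, dense_shape) and parses as empty when the key is absent. That last rule means a misspelled VarLenFeature key produces an empty result rather than an error.

```python
from typing import Any, Dict, List, Optional

class FixedLen:
    """Hypothetical stand-in for FixedLenFeature([], dtype, default_value)."""
    def __init__(self, default: Optional[Any] = None):
        self.default = default  # None means "required"

class VarLen:
    """Hypothetical stand-in for VarLenFeature(dtype)."""

def parse_example(values: Dict[str, List[Any]], schema: Dict[str, Any]) -> Dict[str, Any]:
    """values maps feature name -> list of raw values found in one Example."""
    out = {}
    for name, spec in schema.items():
        vals = values.get(name, [])
        if isinstance(spec, VarLen):
            # Sparse COO layout, like a 1-D SparseTensor
            out[name] = {"indices": [[i] for i in range(len(vals))],
                         "values": list(vals),
                         "dense_shape": [len(vals)]}
        elif vals:
            if len(vals) != 1:
                raise ValueError(f"{name}: expected exactly 1 value, got {len(vals)}")
            out[name] = vals[0]
        elif spec.default is not None:
            out[name] = spec.default
        else:
            raise ValueError(f"{name}: required feature is missing")
    return out
```

For instance, parsing {"user_id": [1], "viewed_pois": [658, 325]} with a schema that also declares avg_paid as FixedLen(0.0) fills avg_paid with 0.0 and returns viewed_pois as a two-element sparse result.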
06